{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "aca4485f",
   "metadata": {},
   "source": [
    "# Simple classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d0aa2e71",
   "metadata": {},
   "outputs": [],
   "source": [
    "%run -i \"../util/file_utils.ipynb\"\n",
    "%run -i \"../util/lang_utils.ipynb\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "46078ddf",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset rotten_tomatoes (/home/zhenya/.cache/huggingface/datasets/rotten_tomatoes/default/1.0.0/40d411e45a6ce3484deed7cc15b82a53dad9a72aafd9f86f8f227134bec5ca46)\n",
      "Found cached dataset rotten_tomatoes (/home/zhenya/.cache/huggingface/datasets/rotten_tomatoes/default/1.0.0/40d411e45a6ce3484deed7cc15b82a53dad9a72aafd9f86f8f227134bec5ca46)\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "# Work on a subset of each split: the first 15% of the rows plus the last 15%.\n",
    "# NOTE(review): the classification report further down shows equal support for both labels,\n",
    "# so the head and tail of each split presumably hold different classes — confirm against the dataset card.\n",
    "train_dataset = load_dataset(\"rotten_tomatoes\", split=\"train[:15%]+train[-15%:]\")\n",
    "test_dataset = load_dataset(\"rotten_tomatoes\", split=\"test[:15%]+test[-15%:]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "945786df",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2560\n",
      "320\n"
     ]
    }
   ],
   "source": [
    "# Report how many examples ended up in each subset (train first, then test).\n",
    "print(len(train_dataset), len(test_dataset), sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5c08343b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class POS_vectorizer:\n",
    "    def __init__(self, spacy_model):\n",
    "        self.model = spacy_model\n",
    "    \n",
    "    def vectorize(self, input_text):\n",
    "        doc = self.model(input_text)\n",
    "        vector = []\n",
    "        vector.append(len(doc))\n",
    "        pos = {\"VERB\":0, \"NOUN\":0, \"PROPN\":0, \"ADJ\":0, \"ADV\":0, \"AUX\":0, \"PRON\":0, \"NUM\":0, \"PUNCT\":0}\n",
    "        for token in doc:\n",
    "            if token.pos_ in pos.keys():\n",
    "                pos[token.pos_] += 1\n",
    "        vector_values = list(pos.values())\n",
    "        vector = vector + vector_values\n",
    "        return vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "68623062",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity-check the vectorizer on the first training example.\n",
    "sample_text = train_dataset[0][\"text\"]\n",
    "# NOTE(review): `small_model` is not defined in this notebook; presumably it is a spaCy pipeline\n",
    "# injected by the `%run -i` of ../util/lang_utils.ipynb at the top — confirm.\n",
    "vectorizer = POS_vectorizer(small_model)\n",
    "vector = vectorizer.vectorize(sample_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "54b2086e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the rock is destined to be the 21st century's new \" conan \" and that he's going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .\n",
      "[38, 3, 8, 3, 4, 1, 3, 1, 0, 5]\n"
     ]
    }
   ],
   "source": [
    "# Show the raw review followed by its POS-count vector.\n",
    "print(sample_text, vector, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "30338ada",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "# Shuffle the training rows. `DataFrame.sample` returns a new frame, so the result must be\n",
    "# assigned; a fixed seed keeps Restart & Run All reproducible.\n",
    "train_df = (\n",
    "    train_dataset.to_pandas()\n",
    "    .sample(frac=1, random_state=42)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "test_df = test_dataset.to_pandas()\n",
    "\n",
    "# One POS-count vector (a Python list) per review.\n",
    "train_df[\"vector\"] = train_df[\"text\"].apply(vectorizer.vectorize)\n",
    "test_df[\"vector\"] = test_df[\"text\"].apply(vectorizer.vectorize)\n",
    "\n",
    "# Stack the per-row lists into feature matrices; labels are pulled out as NumPy arrays.\n",
    "X_train = np.stack(train_df[\"vector\"].values, axis=0)\n",
    "X_test = np.stack(test_df[\"vector\"].values, axis=0)\n",
    "y_train = train_df[\"label\"].to_numpy()\n",
    "y_test = test_df[\"label\"].to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "22702397",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import classification_report\n",
    "# C is scikit-learn's inverse regularization strength; fit() returns the fitted estimator,\n",
    "# so construction and training can be chained.\n",
    "clf = LogisticRegression(C=0.1).fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "76ed3170",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       0.59      0.54      0.56       160\n",
      "           1       0.57      0.62      0.60       160\n",
      "\n",
      "    accuracy                           0.58       320\n",
      "   macro avg       0.58      0.58      0.58       320\n",
      "weighted avg       0.58      0.58      0.58       320\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Predict the whole test matrix in one call instead of invoking the model once per row.\n",
    "test_df[\"prediction\"] = clf.predict(X_test)\n",
    "# Per-class precision/recall/F1 on the held-out subset.\n",
    "print(classification_report(test_df[\"label\"], test_df[\"prediction\"]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
